from typing import Any
from warnings import warn

import cachetools
import numpy as np
from beartype import beartype
from typing_extensions import Self

# Log-probability used as the floor / fill value for tokens that are absent or
# carry (almost) no probability mass (see TokenLogprobs and _common_lps_with_other).
ABSENT_TOKEN_LOGPROB = -100.0
# Log-probability cutoffs; not referenced in this part of the file -- presumably
# thresholds used by callers elsewhere (TODO confirm). The second equals log(1%).
LOGPROB_CUTOFF_CLASSIFICATION = -20.0
LOGPROB_CUTOFF_RELATIVE_PRECISION = np.log(0.01)

# Bounded LRU caches backing the @cachetools.cached helpers in this module.
CACHE_tokenset_from_hf_model = cachetools.LRUCache(maxsize=16)
CACHE_tokenset_from_openai_model = cachetools.LRUCache(maxsize=16)
CACHE_common_lps_with_other = cachetools.LRUCache(maxsize=1024)


@beartype
def logsumexp(a: np.ndarray | list[float] | tuple[float, ...]) -> np.floating:
    """Return ``log(sum(exp(a)))``, computed stably by factoring out the maximum.

    Edge cases:
      * empty input -> ``-inf`` (the log of an empty sum);
      * all entries ``-inf`` -> ``-inf`` (the naive max-shift would give ``nan``);
      * any entry ``+inf`` -> ``+inf``.
    """
    a = np.asarray(a)
    if a.size == 0:
        return np.float64(-np.inf)
    amax = np.max(a)
    # Shifting by a non-finite maximum would compute inf - inf = nan, so skip the shift.
    shift = amax if np.isfinite(amax) else 0.0
    with np.errstate(divide="ignore"):  # log(0) -> -inf is the intended result here
        return np.log(np.sum(np.exp(a - shift))) + shift


@beartype
class TokenSet:
    """An ordered tuple of token strings with a token -> position index.

    A TokenSet may be tagged with the name of the model whose vocabulary it holds.
    Two tagged sets compare equal by name alone, without comparing their tokens,
    which keeps comparisons of large vocabularies cheap.
    """

    def __init__(
        self,
        tokens: list[str] | tuple[str, ...],
        model_name: str | None = None,
        sort=False,
        make_unique=False,
    ):
        """
        Args:
            tokens: The token strings, in order.
            model_name: Optional name of the model this vocabulary belongs to.
            sort: Sort the tokens lexicographically.
            make_unique: Drop duplicate tokens (the result is then sorted).
        """
        self.tokens = tuple(tokens)
        if make_unique and len(set(self.tokens)) < len(self.tokens):
            # Deduplication goes through a set, so order is re-established by sorting.
            self.tokens = tuple(sorted(set(self.tokens)))
        if sort:
            self.tokens = tuple(sorted(self.tokens))
        # Index the FINAL (possibly sorted/deduplicated) tokens so that
        # self.tokens[self.index[t]] == t always holds.
        self.index = {t: i for i, t in enumerate(self.tokens)}
        self.model_name = model_name

    def __len__(self) -> int:
        return len(self.tokens)

    def __repr__(self) -> str:
        if self.model_name is not None:
            return f"TokenSet({self.model_name!r}, {len(self.tokens)} tokens)"
        if len(self) <= 5:
            return f"TokenSet({self.tokens!r}, {len(self.tokens)} tokens)"
        return f"TokenSet({self.tokens[:5]!r} ..., {len(self.tokens)} tokens)"

    @classmethod
    def from_model(cls, model_name_or_path: str) -> Self:
        """Return the (cached) TokenSet of a model's vocabulary.

        OpenAI models (per ``.queries.is_openai_model``) use their tiktoken encoding;
        everything else is loaded as a Hugging Face tokenizer.
        """
        from .queries import is_openai_model

        if not is_openai_model(model_name_or_path):
            return _tokenset_from_hf_model(model_name_or_path)
        else:
            return _tokenset_from_openai_model(model_name_or_path)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, TokenSet):
            return False
        if self.model_name != other.model_name:
            return False
        if self.model_name is not None:
            # Same named model: assume identical vocabularies without comparing tokens.
            return True
        return self.tokens == other.tokens  # Slow, should only happen for small ones

    def __hash__(self) -> int:
        # Consistent with __eq__: named sets hash by name, anonymous ones by tokens.
        if self.model_name is not None:
            return hash(self.model_name)
        return hash(self.tokens)

    def tokenset_union(self, other: Self) -> Self:
        """Return a TokenSet containing the tokens of both sets.

        Returns ``self`` when both sets are tagged with the same model or hold
        identical tokens. Otherwise returns a new, untagged TokenSet of the sorted
        union, warning if either input was tagged with a model name.
        """
        if self.model_name is not None and self.model_name == other.model_name:
            return self
        if self.tokens == other.tokens:
            return self
        if self.model_name is not None or other.model_name is not None:
            warn(f"Combining TokenSets of different models: {self!r} with {other!r}")
        return TokenSet(sorted(set(self.tokens) | set(other.tokens)))


@cachetools.cached(CACHE_tokenset_from_hf_model)
def _tokenset_from_hf_model(model_name_or_path: str) -> TokenSet:
    """Build (and cache) the TokenSet of a Hugging Face tokenizer's base vocabulary.

    Token ids ``0 .. tokenizer.vocab_size - 1`` are converted to strings. For
    ``gpt2*`` models each token goes through ``convert_tokens_to_string``. For
    other models every ``"▁"`` character (presumably the SentencePiece word-boundary
    marker) is replaced by a space.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    # TODO: include extra/additional tokens of the tokenizers (needed for e.g. Orca)
    raw_tokens = tokenizer.convert_ids_to_tokens(list(range(tokenizer.vocab_size)))
    if model_name_or_path.startswith("gpt2"):
        tokens = [tokenizer.convert_tokens_to_string([t]) for t in raw_tokens]
    else:
        tokens = [t.replace("▁", " ") for t in raw_tokens]
    return TokenSet(tokens, model_name=model_name_or_path)


@cachetools.cached(CACHE_tokenset_from_openai_model)
def _tokenset_from_openai_model(model_name: str) -> TokenSet:
    """Build (and cache) a TokenSet from the tiktoken encoding used by ``model_name``.

    Each id below ``n_vocab - 500`` is decoded on its own into a token string.
    """
    import tiktoken

    encoding = tiktoken.encoding_for_model(model_name)
    # HACK: the top 500 ids are skipped -- presumably to avoid special/unassigned ids
    # that cannot be decoded individually; verify against the encodings in use.
    usable_ids = range(encoding.n_vocab - 500)
    decoded = encoding.decode_batch([[token_id] for token_id in usable_ids])
    return TokenSet(decoded, model_name=model_name)


@beartype
class TokenLogprobs:
    """Log-probabilities over a TokenSet, plus a single shared log-probability
    (``others_logprob``) for every token outside that set.
    """

    def __init__(
        self,
        tokens: TokenSet | list[str],
        logprobs: np.ndarray,
        others_logprob: float | np.floating | None = None,
        sort=False,
    ):
        """
        Args:
            tokens: The tokens, as a TokenSet or a plain list of strings.
            logprobs: One log-probability per token; stored as float32.
            others_logprob: Log-probability reported for any token not in ``tokens``.
                If None, it is the log of the probability mass not covered by
                ``logprobs``, floored at ``exp(ABSENT_TOKEN_LOGPROB)``.
            sort: Reorder tokens by decreasing log-probability (slow for full vocabularies).
        """
        self.logprobs = np.array(logprobs, dtype=np.float32)
        assert self.logprobs.shape == (len(tokens),)

        if others_logprob is not None:
            self.others_logprob = others_logprob
        else:
            # Remaining probability mass, floored so that its log stays finite.
            rem_p = np.maximum(np.exp(ABSENT_TOKEN_LOGPROB), 1.0 - np.sum(np.exp(self.logprobs)))
            self.others_logprob = np.log(rem_p)

        if sort:
            if isinstance(tokens, TokenSet):
                if tokens.model_name is not None:
                    warn(f"Sorting TokenLogprobs of model {tokens.model_name!r} is likely very slow!")
                tokens = tokens.tokens
            sort_idx = np.argsort(self.logprobs)[::-1]
            self.logprobs = self.logprobs[sort_idx]
            tokens = [tokens[i] for i in sort_idx]

        if not isinstance(tokens, TokenSet):
            tokens = TokenSet(tokens)
        self.tokenset = tokens

    @property
    def tokens(self):
        return self.tokenset.tokens

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, token: str) -> float | np.floating:
        """Log-probability of ``token``, or ``others_logprob`` if it is not in the set."""
        if token in self.tokenset.index:
            return self.logprobs[self.tokenset.index[token]]
        return self.others_logprob

    def with_temperature(self, temperature: float | np.floating | None) -> Self:
        """Apply a sampling temperature; only 1.0/None (identity) and 0.0 (argmax).

        Temperature 0.0 puts all mass (logprob 0.0) on the most likely token and
        ``ABSENT_TOKEN_LOGPROB`` on everything else.

        Raises:
            ValueError: for any other temperature.
        """
        if temperature is None or np.isclose(temperature, 1.0):
            return self
        elif np.isclose(temperature, 0.0):
            imax = np.argmax(self.logprobs)
            logprobs = np.full_like(self.logprobs, ABSENT_TOKEN_LOGPROB)
            logprobs[imax] = 0.0
            return TokenLogprobs(self.tokenset, logprobs, others_logprob=ABSENT_TOKEN_LOGPROB)
        else:
            # General temperatures (logprobs / T, renormalized together with
            # others_logprob) are deliberately not supported.
            raise ValueError("Use only temperature 0.0 and 1.0 (None) for stable results")

    def top_tokens(self, n=None, min_logprob=None) -> Self:
        """Return the most likely tokens, sorted by decreasing log-probability.

        Args:
            n: Keep at most this many tokens.
            min_logprob: Keep only tokens with log-probability >= this value.

        With neither argument, returns ``self``. Otherwise the result's
        ``others_logprob`` is recomputed from the probability mass of the dropped tokens.
        """
        if n is None and min_logprob is None:
            return self
        if min_logprob is not None:
            idx = np.flatnonzero(self.logprobs >= min_logprob)
        else:
            idx = np.arange(len(self))
        idx = idx[np.argsort(self.logprobs[idx])[::-1]]
        if n is not None:
            idx = idx[:n]
        return TokenLogprobs(
            [self.tokenset.tokens[i] for i in idx],
            self.logprobs[idx],
        )

    def logprobs_for_tokenset(self, tokenset: TokenSet, others_logprob=None, upper_bound=False) -> np.ndarray:
        """Return log-probabilities aligned with ``tokenset.tokens``.

        Tokens this object does not know get ``others_logprob`` if given, else
        ``self.others_logprob``. With ``upper_bound=True``, the result is shifted down
        so that its total probability does not exceed 1. The returned array is always
        a fresh array that the caller may modify.
        """
        fill = self.others_logprob if others_logprob is None else others_logprob
        if tokenset == self.tokenset:
            # Copy so callers can modify the result without corrupting this object.
            res = self.logprobs.copy()
        else:
            index = self.tokenset.index
            res = np.array(
                [self.logprobs[index[t]] if t in index else fill for t in tokenset.tokens],
                dtype=np.float32,
            )
        if upper_bound:
            total = logsumexp(res)
            if total > 0.0:
                res = res - total
        return res

    def __repr__(self) -> str:
        if len(self.tokenset) <= 5:
            return f"TokenLogprobs({self.tokenset!r}, {self.logprobs!r})"
        return f"TokenLogprobs({self.tokenset!r}, {self.logprobs[:5]!r} ...)"

    def extrapolate_from_weakened(self, other: Self, alpha: float | np.floating) -> Self:
        """Combine with a weaker distribution as ``(1 + alpha) * self - alpha * other``.

        The combination is done in log space over the joint token set, with
        "<other_tokens>" treated as one extra token, and then renormalized:

        alpha=0.0 -> self (approx; with joint token set)
        alpha=-1.0 -> other (approx; with joint token set)

        Tokens that ``self`` does not list are capped at the minimum of their
        combined value, ``other``'s value for them, ``self.others_logprob`` and
        ``self``'s smallest listed log-probability.
        """
        tokenset, tokenset_ext, s_lps, o_lps = _common_lps_with_other(self, other)
        oti = tokenset_ext.index["<other_tokens>"]
        # Loop-invariant: the smallest log-probability self assigns to a listed token.
        self_min_lp = np.min(self.logprobs)

        # Interpolation including "other tokens" treated as a single, special token
        lps = -alpha * o_lps + (1.0 + alpha) * s_lps

        # NOTE(review): capping the catch-all bucket this way is a heuristic (the
        # original author marked it "??").
        lps[oti] = min(lps[oti], self.others_logprob, self_min_lp)
        for t in tokenset.tokens:
            if t not in self.tokenset.index:
                ti = tokenset_ext.index[t]
                lps[ti] = min(lps[ti], o_lps[ti], self.others_logprob, self_min_lp)

        lps = lps - logsumexp(lps)
        return TokenLogprobs(tokenset, TokenLogprobs(tokenset_ext, lps).logprobs_for_tokenset(tokenset))

    def tvd(self, other: Self) -> float:
        """Total variation distance ``0.5 * sum |p - q|`` over the joint token set,
        including the "<other_tokens>" bucket."""
        _tokenset, _tokenset_ext, s_lps, o_lps = _common_lps_with_other(self, other)
        # float(): with float32 inputs NumPy may yield np.float32, which is not a
        # Python float and would fail the ``-> float`` return check.
        return float(0.5 * np.sum(np.abs(np.exp(s_lps) - np.exp(o_lps))))


@cachetools.cached(CACHE_common_lps_with_other)
def _common_lps_with_other(s: TokenLogprobs, o: TokenLogprobs) -> tuple[TokenSet, TokenSet, np.ndarray, np.ndarray]:
    """Express two distributions over a shared, extended token set.

    Builds the union of both token sets (``tokenset``) and that union plus a
    synthetic ``"<other_tokens>"`` entry (``tokenset_ext``). Each distribution is
    mapped onto ``tokenset_ext``:

    * tokens it does not list get ``min(others_logprob, smallest listed logprob)``;
    * the vector is shifted down if its total probability exceeds 1;
    * ``"<other_tokens>"`` receives the remaining mass, floored at
      ``exp(ABSENT_TOKEN_LOGPROB)``, so each vector sums to ~1 in probability space.

    Results are memoized in ``CACHE_common_lps_with_other`` keyed on ``(s, o)``. The
    returned arrays are shared between calls, so callers must not modify them in place.

    Returns:
        ``(tokenset, tokenset_ext, s_lps, o_lps)`` with ``s_lps``/``o_lps`` aligned to
        ``tokenset_ext.tokens``.
    """

    tokenset = s.tokenset.tokenset_union(o.tokenset)
    tokenset_ext = tokenset.tokenset_union(TokenSet(["<other_tokens>"]))
    oti = tokenset_ext.index["<other_tokens>"]

    # .copy(): logprobs_for_tokenset may hand back the input's own logprobs array
    # when the token sets compare equal; the slot writes below must not mutate it.
    s_lps = s.logprobs_for_tokenset(
        tokenset_ext, upper_bound=True, others_logprob=np.minimum(s.others_logprob, np.min(s.logprobs))
    ).copy()
    assert logsumexp(s_lps) < 1e-4, f"{logsumexp(s_lps)=}"
    # Reset the bucket to the floor first so its previous value is (almost) excluded
    # from the sum, then give it the leftover probability mass.
    s_lps[oti] = ABSENT_TOKEN_LOGPROB
    s_lps[oti] = np.log(np.maximum(np.exp(ABSENT_TOKEN_LOGPROB), 1.0 - np.sum(np.exp(s_lps))))
    assert np.isclose(logsumexp(s_lps), 0.0, atol=1e-4), f"{logsumexp(s_lps)=} {s_lps=} {oti=}"

    o_lps = o.logprobs_for_tokenset(
        tokenset_ext, upper_bound=True, others_logprob=np.minimum(o.others_logprob, np.min(o.logprobs))
    ).copy()
    assert logsumexp(o_lps) < 1e-4, f"{logsumexp(o_lps)=}"
    o_lps[oti] = ABSENT_TOKEN_LOGPROB
    o_lps[oti] = np.log(np.maximum(np.exp(ABSENT_TOKEN_LOGPROB), 1.0 - np.sum(np.exp(o_lps))))
    assert np.isclose(logsumexp(o_lps), 0.0, atol=1e-4), f"{logsumexp(o_lps)=} {o_lps=} {oti=}"

    return tokenset, tokenset_ext, s_lps, o_lps
